{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e8814b34",
   "metadata": {},
   "source": [
    "# Serverless Inference with Hugging Face's Transformers & Amazon SageMaker\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a1644f1",
   "metadata": {},
   "source": [
    "Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and Amazon SageMaker Python SDK to create a [Serverless Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints.html) endpoint.\n",
    "Amazon SageMaker Serverless Inference is a new capability in SageMaker that enables you to deploy and scale ML models in a Serverless fashion. Serverless endpoints automatically launch compute resources and scale them in and out depending on traffic similar to AWS Lambda.\n",
    "Serverless Inference is ideal for workloads which have idle periods between traffic spurts and can tolerate cold starts. With a pay-per-use model, Serverless Inference is a cost-effective option if you have an infrequent or unpredictable traffic pattern. \n",
    "\n",
    "## How it works \n",
    "\n",
    "The following diagram shows the workflow of Serverless Inference and the benefits of using a serverless endpoint.\n",
    "\n",
    "![architecture](./imgs/e2e.png)\n",
    "\n",
    "When you create a serverless endpoint, SageMaker provisions and manages the compute resources for you. Then, you can make inference requests to the endpoint and receive model predictions in response. SageMaker scales the compute resources up and down as needed to handle your request traffic, and you only pay for what you use.\n",
    "\n",
    "## Limitations\n",
    "\n",
    "Memory size: 1024 MB, 2048 MB, 3072 MB, 4096 MB, 5120 MB, or 6144 MB  \n",
    "Concurrent invocations: 50 per region  \n",
    "Cold starts: ms to seconds. Can be monitored with the `ModelSetupTime` Cloudwatch Metric  \n",
    "\n",
    "\n",
    "_NOTE: You can run this demo in Sagemaker Studio, your local machine, or Sagemaker Notebook Instances_\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53ddbf58",
   "metadata": {},
   "source": [
    "## Development Environment and Permissions\n",
    "\n",
    "### Installation \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eafe1c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sagemaker --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e4386d9",
   "metadata": {},
   "source": [
    "### Permissions\n",
    "\n",
    "_If you are going to use Sagemaker in a local environment (not SageMaker Studio or Notebook Instances). You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1c22e8d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role\n",
      "sagemaker bucket: sagemaker-us-east-1-558105141721\n",
      "sagemaker session region: us-east-1\n"
     ]
    }
   ],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "sess = sagemaker.Session()\n",
    "# sagemaker session bucket -> used for uploading data, models and logs\n",
    "# sagemaker will automatically create this bucket if it not exists\n",
    "sagemaker_session_bucket=None\n",
    "if sagemaker_session_bucket is None and sess is not None:\n",
    "    # set to default bucket if a bucket name is not given\n",
    "    sagemaker_session_bucket = sess.default_bucket()\n",
    "\n",
    "try:\n",
    "    role = sagemaker.get_execution_role()\n",
    "except ValueError:\n",
    "    iam = boto3.client('iam')\n",
    "    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
    "\n",
    "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
    "\n",
    "print(f\"sagemaker role arn: {role}\")\n",
    "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
    "print(f\"sagemaker session region: {sess.boto_region_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2daa0fe6",
   "metadata": {},
   "source": [
    "## Create Inference `HuggingFaceModel` for the Serverless Inference Endpoint\n",
    "\n",
    "We use the [distilbert-base-uncased-finetuned-sst-2-english](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) model running our serverless endpoint. This model is a fine-tune checkpoint of [DistilBERT-base-uncased](https://huggingface.co/distilbert-base-uncased), fine-tuned on SST-2. This model reaches an accuracy of 91.3 on the dev set (for comparison, Bert bert-base-uncased version reaches an accuracy of 92.7).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "03a47b96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---"
     ]
    }
   ],
   "source": [
    "from sagemaker.huggingface.model import HuggingFaceModel\n",
    "from sagemaker.serverless import ServerlessInferenceConfig\n",
    "\n",
    "# Hub Model configuration. <https://huggingface.co/models>\n",
    "hub = {\n",
    "    'HF_MODEL_ID':'distilbert-base-uncased-finetuned-sst-2-english',\n",
    "    'HF_TASK':'text-classification'\n",
    "}\n",
    "\n",
    "# create Hugging Face Model Class\n",
    "huggingface_model = HuggingFaceModel(\n",
    "   env=hub,                      # configuration for loading model from Hub\n",
    "   role=role,                    # iam role with permissions to create an Endpoint\n",
    "   transformers_version=\"4.26\",  # transformers version used\n",
    "   pytorch_version=\"1.13\",        # pytorch version used\n",
    "   py_version='py39',            # python version used\n",
    ")\n",
    "\n",
    "# Specify MemorySizeInMB and MaxConcurrency in the serverless config object\n",
    "serverless_config = ServerlessInferenceConfig(\n",
    "    memory_size_in_mb=4096, max_concurrency=10,\n",
    ")\n",
    "\n",
    "# deploy the endpoint endpoint\n",
    "predictor = huggingface_model.deploy(\n",
    "    serverless_inference_config=serverless_config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b3812f",
   "metadata": {},
   "source": [
    "## Request Serverless Inference Endpoint using the `HuggingFacePredictor`\n",
    "\n",
    "The `.deploy()` returns an `HuggingFacePredictor` object which can be used to request inference. This `HuggingFacePredictor` makes it easy to send requests to your endpoint and get the results back.\n",
    "\n",
    "_The first request might have some coldstart (2-5s)._ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51c5366b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'label': 'POSITIVE', 'score': 0.9998838901519775}]\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "  \"inputs\": \"the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .\",\n",
    "}\n",
    "\n",
    "res = predictor.predict(data=data)\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f9817a1",
   "metadata": {},
   "source": [
    "## Clean up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6fb7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_model()\n",
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9"
  },
  "kernelspec": {
   "display_name": "conda_pytorch_p39",
   "language": "python",
   "name": "conda_pytorch_p39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
